"""Phase 3b: Fix gravity section + time variation + quadrant regressions."""

import sys, os
import pandas as pd
import numpy as np
sys.path.insert(0, '/mnt/c/demographics_capital_flows/multilateral/src')
from model import PanelGLS

# Output directory for generated tables.
OUT = '/mnt/c/demographics_capital_flows/eu_demographics/output/tables'

# Country-year panel; only observations up to and including 2024 are kept.
panel = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/processed/full_panel_with_resources.csv')
panel = panel[panel['year'] <= 2024].copy()

# Commodity-exporter dummy: 1.0 when `resource_rents_gdp` is at least 10.
# `NaN >= 10` evaluates to False, so without the mask a country-year with
# missing rents would be coded 0.0 ("not an exporter") instead of staying
# missing and being dropped by any later `dropna()`.
rents_known = panel['resource_rents_gdp'].notna()
panel['commodity_exporter'] = (
    (panel['resource_rents_gdp'] >= 10).astype(float).where(rents_known)
)
# Continuous interaction of Z_1 with resource rents (NaN if either is missing).
panel['Z1_x_resource'] = panel['Z_1'] * panel['resource_rents_gdp']

# Regressor sets shared by the panel regressions.
controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw', 'kaopen']
base_vars = ['Z_1', 'Z_2', 'Z_3'] + controls


def run_model(name, dep, indep, data, entity_col='iso3', quiet=False):
    """Fit a ``PanelGLS`` regression of ``dep`` on ``indep`` and summarise it.

    Rows with a missing value in the entity column, ``year``, ``dep`` or any
    regressor are dropped before estimation.

    Args:
        name: Label used in printed output and stored under ``'model'``.
        dep: Name of the dependent-variable column in ``data``.
        indep: List of regressor column names, in the order passed to the fit.
        data: DataFrame holding ``entity_col``, ``'year'``, ``dep`` and
            every column in ``indep``.
        entity_col: Column identifying the panel unit.
        quiet: If true, skip the coefficient table (the insufficient-obs
            notice is printed regardless).

    Returns:
        ``None`` when fewer than 50 complete rows remain; otherwise a dict
        with keys ``model``, ``n_obs``, ``r_squared``, ``vars``, ``beta``,
        ``se`` and ``pvalues`` read from the fitted model.
    """
    def stars(p):
        # Significance markers at the 1%, 5% and 10% levels.
        for cutoff, mark in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
            if p < cutoff:
                return mark
        return ''

    needed = [entity_col, 'year', dep] + indep
    sample = data[needed].dropna()
    n_rows = len(sample)
    if n_rows < 50:
        print(f"  {name}: insufficient obs ({n_rows})")
        return None

    fitted = PanelGLS()
    fitted.fit(
        sample[dep].values,
        sample[indep].values,
        sample[entity_col].values,
        sample['year'].values,
    )

    if not quiet:
        print(f"\n  {name}: N={fitted.n_obs}, R²={fitted.r_squared:.4f}")
        for pos, var in enumerate(indep):
            flag = stars(fitted.pvalues[pos])
            print(f"    {var:35s} {fitted.beta[pos]:10.4f} ({fitted.se[pos]:.4f}){flag}")

    summary = {
        'model': name,
        'n_obs': fitted.n_obs,
        'r_squared': fitted.r_squared,
        'vars': indep,
        'beta': fitted.beta,
        'se': fitted.se,
        'pvalues': fitted.pvalues,
    }
    return summary


# ═══════════════════════════════════════════════════════════════════════
# SECTION 7: BILATERAL FLOWS
# ═══════════════════════════════════════════════════════════════════════
print("=" * 70)
print("SECTION 7: Bilateral Flows to/from Commodity Exporters")
print("=" * 70)

bilateral = pd.read_csv('/mnt/c/demographics_capital_flows/gravity_bilateral/data/processed/bilateral_panel.csv')
rents = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/raw/wdi_resource_rents.csv')

# Attach resource rents for the destination (`_d`) and origin (`_o`) side of
# each pair. Left joins keep pairs that have no rents record for that
# country-year; those rows get NaN rents.
rents_cols = rents[['iso3', 'year', 'resource_rents_gdp']]
for side_col, suffix in (('iso_d', 'd'), ('iso_o', 'o')):
    bilateral = bilateral.merge(
        rents_cols.rename(columns={'iso3': side_col,
                                   'resource_rents_gdp': f'resource_rents_{suffix}'}),
        on=[side_col, 'year'], how='left')

# Commodity-exporter dummies (rents >= 10) and their interactions with dZ_1.
# `NaN >= 10` is False, so a bare comparison would code every pair with
# missing rents as 0.0 and keep it in the interaction regressions as a
# "non-commodity" observation. Masking with `notna()` keeps the dummy (and
# therefore the interaction) missing, so run_model's dropna() excludes them.
for dummy, rents_col in (('comm_dest', 'resource_rents_d'),
                         ('comm_origin', 'resource_rents_o')):
    rents_known = bilateral[rents_col].notna()
    bilateral[dummy] = (bilateral[rents_col] >= 10).astype(float).where(rents_known)
    bilateral[f'dZ1_x_{dummy}'] = bilateral['dZ_1'] * bilateral[dummy]

# Use pair_id as entity
if 'pair_id' not in bilateral.columns:
    bilateral['pair_id'] = bilateral['iso_o'] + '_' + bilateral['iso_d']

# Only the gravity controls actually present in the file are used.
gravity_controls = [c for c in ['log_dist', 'contiguity', 'common_lang_official']
                    if c in bilateral.columns]

dz_vars = ['dZ_1', 'dZ_2', 'dZ_3']

for fv in ['portfolio_total', 'portfolio_debt', 'fdi_outward']:
    # Skip flow measures that are absent or have 500 or fewer non-missing values.
    if fv not in bilateral.columns or bilateral[fv].notna().sum() <= 500:
        continue

    dep = f'log_{fv}'
    # NOTE: clip(lower=1) floors every value below 1 (zeros and negatives
    # included) at 1, so those observations enter the regression as log(1) = 0.
    bilateral[dep] = np.log(bilateral[fv].clip(lower=1))

    print(f"\n--- {fv} ---")
    # (label, extra regressors) for the baseline and the two interaction specs.
    specs = [
        (f'Gravity {fv}: dZ₁ baseline', []),
        (f'Gravity {fv}: + dZ₁ × comm_dest', ['comm_dest', 'dZ1_x_comm_dest']),
        (f'Gravity {fv}: + dZ₁ × comm_origin', ['comm_origin', 'dZ1_x_comm_origin']),
    ]
    for label, extra in specs:
        run_model(label, dep, dz_vars + gravity_controls + extra,
                  bilateral, entity_col='pair_id')

# ═══════════════════════════════════════════════════════════════════════
# SECTION 8: TIME VARIATION
# ═══════════════════════════════════════════════════════════════════════
print("\n\n" + "=" * 70)
print("SECTION 8: Time Variation in Commodity-Demographic Interaction")
print("=" * 70)


def st(p):
    """Return significance stars for p-value ``p`` (1%, 5%, 10% cutoffs)."""
    if p < 0.01:
        return '***'
    if p < 0.05:
        return '**'
    if p < 0.1:
        return '*'
    return ''


# Same specification in every window: base regressors plus resource rents
# and the Z_1 x rents interaction.
interaction_vars = base_vars + ['resource_rents_gdp', 'Z1_x_resource']
windows = {
    '1980-1999': (1980, 1999),
    '2000-2009': (2000, 2009),
    '2010-2021': (2010, 2021),
}

for period, (first_year, last_year) in windows.items():
    # between() is inclusive on both ends, matching the window labels.
    in_window = panel['year'].between(first_year, last_year)
    res = run_model(f'Period {period}', 'ca_gdp', interaction_vars,
                    panel[in_window], quiet=True)
    if not res:
        # run_model already reported why the window was skipped.
        continue

    beta, pvals = res['beta'], res['pvalues']
    i_z1 = res['vars'].index('Z_1')
    i_int = res['vars'].index('Z1_x_resource')
    i_rents = res['vars'].index('resource_rents_gdp')
    print(f"  {period}: N={res['n_obs']}, R²={res['r_squared']:.4f}, "
          f"Z₁={beta[i_z1]:.3f}{st(pvals[i_z1])}, "
          f"Z₁×Res={beta[i_int]:.4f}{st(pvals[i_int])}, "
          f"Rents={beta[i_rents]:.3f}{st(pvals[i_rents])}")

# ═══════════════════════════════════════════════════════════════════════
# SECTION 9: QUADRANT REGRESSIONS
# ═══════════════════════════════════════════════════════════════════════
print("\n\n" + "=" * 70)
print("SECTION 9: Country Classification — Demographic-Commodity Matrix")
print("=" * 70)

# Per-country means over 2015-2021. Countries missing either classifier
# (mean Z_1 or mean rents) are dropped; a missing mean CA is tolerated.
recent = panel[(panel['year'] >= 2015) & (panel['year'] <= 2021)]
cc = recent.groupby('iso3').agg(
    mean_Z1=('Z_1', 'mean'), mean_rents=('resource_rents_gdp', 'mean'),
    mean_ca=('ca_gdp', 'mean')
).dropna(subset=['mean_Z1', 'mean_rents'])

# "Old" = positive mean Z_1; "commodity" = mean rents of at least 10.
cc['old'] = cc['mean_Z1'] > 0
cc['commodity'] = cc['mean_rents'] >= 10

quadrants = {
    (True, True): 'Old Commodity',
    (True, False): 'Old Non-Commodity',
    (False, True): 'Young Commodity',
    (False, False): 'Young Non-Commodity'
}

# Build the labels from the two flag columns directly. A row-wise
# DataFrame.apply(axis=1) does the same for non-empty input but can hand back
# a frame instead of a Series when `cc` is empty; a list always assigns cleanly.
cc['quadrant'] = [quadrants[(bool(is_old), bool(is_comm))]
                  for is_old, is_comm in zip(cc['old'], cc['commodity'])]


def st(p):
    """Return significance stars for p-value ``p`` (1%, 5%, 10% cutoffs)."""
    return '***' if p<0.01 else '**' if p<0.05 else '*' if p<0.1 else ''


print("\n  Quadrant summary:")
for q in quadrants.values():
    sub = cc[cc['quadrant'] == q]
    countries = sub.index.tolist()
    # Membership comes from the 2015-2021 means, but the regression uses every
    # available year of the panel for those countries.
    sub_panel = panel[panel['iso3'].isin(countries)]

    r = run_model(q, 'ca_gdp', base_vars, sub_panel, quiet=True)
    if not r:
        print(f"\n  {q:25s}: {len(countries):3d} countries — regression failed")
        continue

    z1_i = r['vars'].index('Z_1')
    print(f"\n  {q:25s}: {len(countries):3d} countries, N={r['n_obs']}, "
          f"Z₁={r['beta'][z1_i]:.3f}{st(r['pvalues'][z1_i])}, R²={r['r_squared']:.4f}")
    print(f"    Mean CA: {sub['mean_ca'].mean():.2f}, Mean Rents: {sub['mean_rents'].mean():.1f}%")
    # List every member for small groups, otherwise only the first ten (sorted).
    if len(countries) <= 20:
        print(f"    Countries: {', '.join(sorted(countries))}")
    else:
        print(f"    Example: {', '.join(sorted(countries)[:10])}...")

# ═══════════════════════════════════════════════════════════════════════
# SECTION 10: OIL EXPORTER DEEP DIVE — Gulf + Russia + Norway
# ═══════════════════════════════════════════════════════════════════════
print("\n\n" + "=" * 70)
print("SECTION 10: Oil Exporter Profiles")
print("=" * 70)

oil_countries = ['SAU', 'KWT', 'ARE', 'QAT', 'OMN', 'BHR',  # Gulf
                 'RUS', 'NOR',  # Major non-Gulf
                 'IRN', 'IRQ', 'DZA', 'LBY', 'AGO', 'NGA',  # Other major
                 'KAZ', 'AZE', 'TKM', 'VEN']

# 2018-2021 averages per oil exporter. Countries in the list that have no
# rows in the panel for these years simply do not appear.
snap = panel[(panel['year'] >= 2018) & (panel['year'] <= 2021)]
profile_spec = {
    'Z1': ('Z_1', 'mean'),
    'OADR': ('old_dep', 'mean'),
    'CA': ('ca_gdp', 'mean'),
    'NFA': ('nfa_gdp', 'mean'),
    'Rents': ('resource_rents_gdp', 'mean'),
    'Fiscal': ('fiscal_bal_gdp', 'mean'),
}
# Savings is included only when its source column exists, so the column never
# holds a stand-in statistic (such as a row count) under a savings label.
if 'gross_savings_gdp' in snap.columns:
    profile_spec['Savings'] = ('gross_savings_gdp', 'mean')
oil_snap = snap[snap['iso3'].isin(oil_countries)].groupby('iso3').agg(**profile_spec)

# Add projections
# NOTE(review): assumes full_panel.csv carries rows for 2030 and 2050 — confirm.
full_proj = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/processed/full_panel.csv')
for yr in [2030, 2050]:
    proj = (full_proj.loc[full_proj['year'] == yr, ['iso3', 'Z_1', 'old_dep']]
            .rename(columns={'Z_1': f'Z1_{yr}', 'old_dep': f'OADR_{yr}'})
            .set_index('iso3'))
    # Left join on the iso3 index: a country without a projection row for this
    # year keeps its profile and shows NaN, rather than being dropped from the
    # table as an inner merge would do.
    oil_snap = oil_snap.join(proj, how='left')

oil_snap = oil_snap.sort_values('Rents', ascending=False)

print("\nOil exporter demographic-commodity profiles:")
print(f"{'Country':6s} {'Rents%':>7s} {'Z₁':>7s} {'Z₁_30':>7s} {'Z₁_50':>7s} "
      f"{'OADR':>6s} {'OADR_50':>8s} {'CA':>6s} {'NFA':>6s} {'Fiscal':>7s}")
print("-" * 78)
for iso, row in oil_snap.iterrows():
    # OADR values are multiplied by 100 for display as percentages.
    print(f"{iso:6s} {row['Rents']:7.1f} {row['Z1']:7.3f} {row.get('Z1_2030', np.nan):7.3f} "
          f"{row.get('Z1_2050', np.nan):7.3f} {row['OADR']*100:5.1f}% "
          f"{row.get('OADR_2050', np.nan)*100:7.1f}% {row['CA']:6.1f} "
          f"{row['NFA']:6.1f} {row['Fiscal']:7.1f}")

print("\n✓ Phase 3b complete.")
